package ru.ifmo.ctddev.genetic.transducer.algorithm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.*;

import ru.ifmo.ctddev.genetic.transducer.scenario.Path;
import ru.ifmo.ctddev.genetic.transducer.scenario.TestGroup;
import ru.ifmo.ctddev.genetic.transducer.verifier.IVerifierFactory;
import ru.ifmo.ctddev.genetic.transducer.verifier.VerifierFactory;
import ru.ifmo.ltl.LtlParseException;

/**
 * Searches for a finite-state transducer ({@link FST}) whose fitness reaches a
 * desired value. The search has two phases: a genetic algorithm that evolves a
 * population until its best individual is close to the target, followed by hill
 * climbing started from the individuals of the final generation.
 *
 * <p>Progress is reported on {@code System.out}; every {@value #SNAPSHOT_PERIOD}
 * generations the best automaton is written to {@code gen<N>.xml} in the
 * current working directory.
 */
public class GeneticAlgorithmPlusHillClimber {
	/** The best automaton is dumped to a file once per this many generations. */
	private static final int SNAPSHOT_PERIOD = 200;

	/** The genetic phase stops once the best fitness reaches this share of the desired fitness. */
	private static final double GENETIC_PHASE_TARGET_RATIO = 0.99;

	/** Best-fitness values closer than this are treated as equal when detecting stagnation. */
	private static final double FITNESS_EPSILON = 1E-15;

	/** Orders automata by fitness, fittest first. */
	private static final Comparator<FST> BY_FITNESS_DESCENDING = new Comparator<FST>() {
		public int compare(FST a1, FST a2) {
			return Double.compare(a2.getFitness(), a1.getFitness());
		}
	};

	private final double desiredFitness;
	private final int populationSize;
	private final int maxStates;

	private FST[] curGeneration;
	private final Random random = new Random();

	/**
	 * Fraction of the population that is carried over into the next generation.
	 * The fittest individuals are the ones carried over; the remaining slots are
	 * filled by crossover and mutation of individuals of the current generation.
	 */
	private final double partStay;

	/**
	 * Probability with which each individual of a freshly built generation is mutated.
	 */
	private final double mutationProb;

	/**
	 * Number of consecutive generations without a change of the best fitness
	 * after which a "small" mutation of the generation is applied: every
	 * individual except the first tenth of the new generation is mutated.
	 * Together with the population size it also bounds the number of
	 * consecutive non-improving steps of the hill climber.
	 */
	private final int timeSmallMutation;

	/**
	 * Number of consecutive generations without a change of the best fitness
	 * after which a "big" mutation of the generation is applied: the whole
	 * new generation is replaced by random individuals.
	 */
	private final int timeBigMutation;

	private final String[] setOfInputs;
	private final String[] setOfOutputs;

	private final List<TestGroup> groups;
	private final ArrayList<Path> tests;

	/**
	 * Best fitness of every generation processed so far, in order. The list is
	 * not cleared between calls of {@link #go()}.
	 */
	ArrayList<Double> generations = new ArrayList<Double>();

	/**
	 * Creates the algorithm and installs the shared fitness calculator.
	 *
	 * <p>Side effect: assigns the static field {@code FST.fitnessCalculator}.
	 *
	 * @param parameters   numeric settings of the search (population size, number
	 *                     of states, mutation settings, extra desired fitness)
	 * @param setOfInputs  inputs passed to the verifier factory and to random individuals
	 * @param setOfOutputs outputs passed to the verifier factory and to random individuals
	 * @param groups       test groups; the tests of all groups are collected into one list
	 * @throws RuntimeException if the verifier fails with an {@link LtlParseException}
	 *                          while preparing the formulas of the groups
	 */
	public GeneticAlgorithmPlusHillClimber(Parameters parameters,
			String[] setOfInputs, String[] setOfOutputs, List<TestGroup> groups) {
		this.populationSize = parameters.getPopulationSize();
		this.maxStates = parameters.getStateNumber();
		this.partStay = parameters.getPartStay();
		this.mutationProb = parameters.getMutationProbability();
		this.timeSmallMutation = parameters.getTimeSmallMutation();
		this.timeBigMutation = parameters.getTimeBigMutation();
		this.setOfInputs = setOfInputs;
		this.setOfOutputs = setOfOutputs;
		this.groups = groups;
		this.tests = new ArrayList<Path>();
		for (TestGroup g : groups) {
			tests.addAll(g.getTests());
		}

		IVerifierFactory verifier = new VerifierFactory(setOfInputs, setOfOutputs);
		try {
			verifier.prepareFormulas(groups.toArray(new TestGroup[groups.size()]));
		} catch (LtlParseException e) {
			// TODO: report the parse failure through a dedicated exception type.
			throw new RuntimeException("Failed to prepare verifier formulas", e);
		}
		FST.fitnessCalculator = new FitnessCalculator(verifier, groups, tests);

		this.desiredFitness = FST.fitnessCalculator.evaluateDesiredFitness(groups) + parameters.getDesiredFitness();
	}

	/**
	 * Fills the current generation with random individuals.
	 */
	private void init() {
		curGeneration = new FST[populationSize];
		for (int i = 0; i < populationSize; i++) {
			curGeneration[i] = FST.randomIndividual(maxStates, setOfInputs, setOfOutputs);
		}
	}

	/**
	 * Runs the genetic phase followed by the hill-climbing phase.
	 *
	 * @return an automaton whose fitness is at least the desired fitness, or
	 *         {@code null} if hill climbing from every individual of the final
	 *         generation gave up before reaching it
	 */
	public FST go() {
		init();
		evolve();
		return hillClimb();
	}

	/**
	 * Genetic phase: evolves {@code curGeneration} until its best fitness reaches
	 * {@code GENETIC_PHASE_TARGET_RATIO * desiredFitness}. The loop has no other
	 * exit, so it runs until that happens.
	 */
	private void evolve() {
		double lastBest = -1;
		int cntBestEqual = 0;

		for (int genCount = 0; ; genCount++) {
			// Infinite sentinels keep min/max correct for any fitness range.
			double maxFitness = Double.NEGATIVE_INFINITY;
			double minFitness = Double.POSITIVE_INFINITY;
			double sumFitness = 0;
			FST bestAutomaton = null;
			for (FST individual : curGeneration) {
				double fitness = individual.getFitness();
				sumFitness += fitness;
				if (maxFitness < fitness) {
					maxFitness = fitness;
					bestAutomaton = individual;
				}
				minFitness = Math.min(minFitness, fitness);
			}

			generations.add(maxFitness);

			System.out.println("Generation " + genCount + " Max fitness = " + maxFitness + " Min fitness = " + minFitness + " Sum fitness = " + sumFitness);
			if (genCount % SNAPSHOT_PERIOD == 0) {
				saveSnapshot(genCount, bestAutomaton);
			}
			if (maxFitness >= GENETIC_PHASE_TARGET_RATIO * desiredFitness) {
				return;
			}

			// Count for how many generations in a row the best fitness has not changed.
			if (Math.abs(maxFitness - lastBest) < FITNESS_EPSILON) {
				cntBestEqual++;
			} else {
				lastBest = maxFitness;
				cntBestEqual = 0;
			}

			curGeneration = buildNextGeneration(cntBestEqual);
		}
	}

	/**
	 * Writes the given automaton to {@code gen<genCount>.xml}. This is a
	 * best-effort diagnostic: if the file cannot be opened, the stack trace is
	 * printed and the search continues.
	 */
	private void saveSnapshot(int genCount, FST bestAutomaton) {
		try {
			PrintWriter out = new PrintWriter(new File("gen" + genCount + ".xml"));
			try {
				FST.printAutomaton(out, bestAutomaton, tests, 0);
			} finally {
				// Close the file even if printing fails.
				out.close();
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Sorts the current generation by fitness (in place, fittest first) and
	 * builds the next one from it.
	 *
	 * @param cntBestEqual number of consecutive generations in which the best
	 *                     fitness has not changed
	 * @return the new generation, {@code populationSize} individuals long
	 */
	private FST[] buildNextGeneration(int cntBestEqual) {
		Arrays.sort(curGeneration, BY_FITNESS_DESCENDING);

		FST[] nextGeneration = new FST[populationSize];

		// Carry over the fittest individuals. The number of remaining slots is
		// made even because offspring are produced in pairs.
		int cntStay = (int) (partStay * populationSize);
		if ((populationSize - cntStay) % 2 == 1) {
			cntStay++;
		}
		System.arraycopy(curGeneration, 0, nextGeneration, 0, cntStay);
		int nextCnt = cntStay;

		// Fill the remaining slots, two individuals per iteration, from two
		// distinct parents chosen uniformly at random.
		int cntAdd = populationSize - cntStay;
		for (int i = 0; i < cntAdd / 2; i++) {
			int num1 = random.nextInt(populationSize);
			int num2 = num1;
			while (num1 == num2) {
				num2 = random.nextInt(populationSize);
			}
			// NOTE(review): this choice uses Const.MUTATION_PROBABILITY rather than
			// the mutationProb field — confirm that this is intentional.
			if (random.nextDouble() > Const.MUTATION_PROBABILITY) {
				FST[] toAdd = FST.crossOver(curGeneration[num1], curGeneration[num2], groups);
				for (int j = 0; j < 2; j++) {
					nextGeneration[nextCnt++] = toAdd[j];
				}
			} else {
				nextGeneration[nextCnt++] = curGeneration[num1].mutate();
				nextGeneration[nextCnt++] = curGeneration[num2].mutate();
			}
		}

		// Regular mutation: applies to every individual, carried-over ones included.
		for (int i = 0; i < populationSize; i++) {
			if (random.nextDouble() < mutationProb) {
				nextGeneration[i] = nextGeneration[i].mutate();
			}
		}

		if (cntBestEqual >= timeSmallMutation) {
			// "Small" mutation of the generation: everything but the first tenth is mutated.
			for (int i = populationSize / 10; i < populationSize; i++) {
				nextGeneration[i] = nextGeneration[i].mutate();
			}
		}

		if (cntBestEqual >= timeBigMutation) {
			// "Big" mutation of the generation: start over from random individuals.
			for (int i = 0; i < populationSize; i++) {
				nextGeneration[i] = FST.randomIndividual(maxStates, setOfInputs, setOfOutputs);
			}
		}

		return nextGeneration;
	}

	/**
	 * Hill-climbing phase: starting from each individual of the current
	 * generation in turn (fittest first), repeatedly mutates it and keeps the
	 * mutant whenever it is not worse. Mutants of equal fitness are accepted
	 * and reset the non-improvement counter. A start point is abandoned after
	 * {@code populationSize * timeSmallMutation} consecutive worse mutants.
	 *
	 * @return the first automaton whose fitness is at least the desired
	 *         fitness, or {@code null} if every start point was abandoned
	 */
	private FST hillClimb() {
		Arrays.sort(curGeneration, BY_FITNESS_DESCENDING);
		int maxNonImprove = populationSize * timeSmallMutation;

		for (FST start : curGeneration) {
			FST current = start;
			int cntNonImprove = 0;
			int itNumber = 1;
			do {
				System.out.println("Iteration " + itNumber + " Fitness = " + current.getFitness());
				itNumber++;
				if (current.getFitness() >= desiredFitness) {
					return current;
				}
				FST candidate = current.mutate();
				if (candidate.getFitness() >= current.getFitness()) {
					cntNonImprove = 0;
					current = candidate;
				} else {
					cntNonImprove++;
				}
			} while (cntNonImprove < maxNonImprove);
		}
		return null;
	}

	/**
	 * @return a read-only view of the tests collected from all test groups
	 */
	public List<Path> getTests() {
		return Collections.unmodifiableList(tests);
	}

	/**
	 * @return the fitness value the search is trying to reach
	 */
	public double getDesiredFitness() {
		return desiredFitness;
	}

}

